Generative Adversarial Imitation Learning (GAIL): low-level PyTorch

GAIL is an imitation learning algorithm: it learns a policy (\pi_\theta(a\mid s)) from expert demonstrations without access to the expert’s reward function.

The core idea is adversarial training:

  • a discriminator (D_\phi(s,a)) tries to tell expert ((s,a)) pairs apart from policy-generated ((s,a)) pairs

  • the policy is trained to fool the discriminator, using a reward derived from (D_\phi)

This notebook implements a small but complete GAIL loop from scratch in PyTorch:

  • a toy 2D navigation environment (no Gym dependency)

  • a hand-coded expert to generate demonstrations

  • a discriminator network (D_\phi)

  • an actor-critic policy (\pi_\theta) trained with PPO using the discriminator reward

  • Plotly curves for discriminator loss, policy learning, and episodic rewards


Learning goals

By the end you should be able to:

  • write down the GAIL min–max objective and explain the GAN analogy

  • implement a discriminator over ((s,a)) and train it with cross-entropy

  • turn discriminator outputs into a reward signal for RL

  • implement the PPO update (clipped objective + value loss + entropy bonus)

  • monitor training with Plotly: discriminator loss + episodic return

Notebook roadmap

  1. GAIL objective + adversarial training equations

  2. A tiny offline-friendly environment + expert demonstrations

  3. Low-level PyTorch: policy/value networks + discriminator

  4. Training loop (alternate discriminator updates and policy PPO updates)

  5. Plotly diagnostics: discriminator loss, policy learning, episodic rewards

  6. Stable-Baselines GAIL notes + hyperparameters (end)

import os
import time
import warnings

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

import torch
import torch.nn as nn
import torch.nn.functional as F


pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

# Some environments emit a CUDA-availability warning even when using CPU tensors.
warnings.filterwarnings("ignore", message=r"CUDA initialization:.*")

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

torch.set_num_threads(1)
DEVICE = torch.device("cpu")

np.set_printoptions(precision=4, suppress=True)

# --- Run configuration ---
FAST_RUN = True  # set False for longer training

# Environment
N_ENVS = 32 if FAST_RUN else 128
MAX_STEPS = 60
STEP_SIZE = 0.20
NOISE_STD = 0.00
GOAL_RADIUS = 0.12

# Expert dataset
N_EXPERT_EPISODES = 300 if FAST_RUN else 1500

# GAIL + PPO
ITERATIONS = 25 if FAST_RUN else 200
STEPS_PER_ITER = 128 if FAST_RUN else 512

GAMMA = 0.99
LAMBDA_GAE = 0.95

# Discriminator updates
D_LR = 3e-4
D_EPOCHS = 2 if FAST_RUN else 5
D_BATCH_SIZE = 512

# PPO updates
PI_LR = 3e-4
PPO_EPOCHS = 4 if FAST_RUN else 10
PPO_BATCH_SIZE = 1024
CLIP_EPS = 0.2
VF_COEF = 0.5
ENT_COEF = 0.01

# Eval
EVAL_EVERY = 5
EVAL_EPISODES = 200 if FAST_RUN else 1000

1) GAIL: adversarial training objective (equations)

Setup

You have expert demonstrations (state-action pairs) sampled from an expert policy (\pi_E):

[ (s,a) \sim \pi_E. ]

You want to learn a policy (\pi_\theta) that induces (approximately) the same occupancy measure as the expert, i.e. the same discounted distribution over the ((s,a)) pairs it visits.

Discriminator objective

GAIL uses a discriminator (D_\phi(s,a)\in(0,1)) that outputs the probability that a ((s,a)) pair came from the expert. It is trained like a GAN discriminator:

[ \max_{\phi}\; \mathbb{E}_{(s,a)\sim \pi_E}[\log D_\phi(s,a)] + \mathbb{E}_{(s,a)\sim \pi_\theta}[\log(1 - D_\phi(s,a))]. ]
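Maximizing this objective is the same as minimizing binary cross-entropy with label 1 for expert pairs and label 0 for policy pairs. The check below is a minimal sketch with made-up logits (the tensors `logits_expert` and `logits_policy` are illustrative stand-ins for the discriminator's outputs, not data from this notebook):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical discriminator logits f(s, a) for a batch of expert and policy pairs.
logits_expert = torch.randn(8)
logits_policy = torch.randn(8)

# Objective as written above, with D = sigmoid(f) and expectations replaced by batch means.
d_expert = torch.sigmoid(logits_expert)
d_policy = torch.sigmoid(logits_policy)
objective = torch.log(d_expert).mean() + torch.log(1.0 - d_policy).mean()

# BCE with labels expert=1, policy=0 (each term is mean-reduced, then the two are summed).
bce = (
    F.binary_cross_entropy_with_logits(logits_expert, torch.ones_like(logits_expert))
    + F.binary_cross_entropy_with_logits(logits_policy, torch.zeros_like(logits_policy))
)

# The two quantities differ only in sign.
print(float(objective), float(-bce))
```

Working with logits and `binary_cross_entropy_with_logits` avoids taking `log` of a probability that has rounded to 0 or 1.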

Policy (generator) objective

The policy plays the role of the GAN generator: it tries to produce ((s,a)) that the discriminator labels as expert. A common generator loss is:

[ \min_{\theta}\; \mathbb{E}_{(s,a)\sim \pi_\theta}[\log(1 - D_\phi(s,a))] - \lambda\,H(\pi_\theta), ]

where (H(\pi_\theta)) is the policy entropy, which encourages exploration, and (\lambda \ge 0) is its weight.

Turning the discriminator into a reward

To train (\pi_\theta) with RL, we convert discriminator outputs into a reward:

[ \hat r_\phi(s,a) = -\log(1 - D_\phi(s,a)). ]

If the discriminator outputs a logit (f_\phi(s,a)) (so (D_\phi = \sigma(f_\phi))), then (1 - \sigma(f) = \sigma(-f)) and the reward has a numerically stable form:

[ \hat r_\phi(s,a) = -\log(\sigma(-f_\phi(s,a))) = \mathrm{softplus}(f_\phi(s,a)). ]
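A quick numerical check of this identity, and of why the softplus form is preferable. This is a small sketch on hand-picked logits; the values in `f` are arbitrary:

```python
import torch
import torch.nn.functional as F

f = torch.tensor([-5.0, 0.0, 5.0, 50.0])  # arbitrary example logits

# Naive form: breaks down once sigmoid(f) rounds to exactly 1 in float32.
naive = -torch.log(1.0 - torch.sigmoid(f))

# Stable form: softplus(f) = log(1 + exp(f)), evaluated without overflow.
stable = F.softplus(f)

print(naive)
print(stable)
```

For moderate logits the two agree; for the large logit the naive form returns `inf` while `softplus` stays finite.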

Policy optimization (we’ll use PPO)

We’ll optimize (\pi_\theta) with PPO. With advantage estimates (\hat A_t), PPO’s clipped objective is:

[ L^{\text{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\hat A_t,\;\mathrm{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\hat A_t\big)\Big], ]

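where (r_t(\theta)=\exp(\log\pi_\theta(a_t\mid s_t)-\log\pi_{\theta_{\text{old}}}(a_t\mid s_t))) is the probability ratio between the new and old policies and (\epsilon) is the clip range (CLIP_EPS in the config above).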
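To see what the `min` and `clip` do together, here is a small worked example on four hand-picked (ratio, advantage) pairs. The numbers are illustrative only:

```python
import torch

eps = 0.2
ratio = torch.tensor([0.5, 1.0, 1.5, 1.5])   # r_t(theta) for four samples
adv = torch.tensor([1.0, 1.0, 1.0, -1.0])    # advantage estimates

unclipped = ratio * adv
clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
surrogate = torch.minimum(unclipped, clipped)

print(surrogate)  # tensor([ 0.5000,  1.0000,  1.2000, -1.5000])
```

The third sample shows the clip in action: the ratio has already moved past 1 + eps in the direction the advantage rewards, so the surrogate is capped at 1.2 and contributes no further gradient. The fourth sample is not capped, because the `min` keeps the more pessimistic (unclipped) value when the move hurts the objective.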

2) A tiny environment (no downloads, no Gym)

We’ll use a simple 2D point navigation task:

  • observation (s = (x,y)\in[-1,1]^2)

  • 5 discrete actions: stay / up / down / left / right

  • start position is random

  • goal is the origin

  • episode ends when the point enters a goal radius or hits a step limit

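We’ll generate expert demonstrations using a greedy hand-coded expert that always moves along the axis with the larger absolute coordinate, toward 0. This expert is good but not perfect: with a step of 0.20 and a goal radius of 0.12 it can hop back and forth across an axis without entering the goal region, which is consistent with the expert baseline reported later succeeding in about 95% of episodes.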

class VectorPointNav2D:
    def __init__(
        self,
        n_envs: int,
        max_steps: int = 60,
        step_size: float = 0.20,
        noise_std: float = 0.00,
        goal: tuple[float, float] = (0.0, 0.0),
        goal_radius: float = 0.12,
        seed: int = 0,
    ):
        self.n_envs = int(n_envs)
        self.max_steps = int(max_steps)
        self.step_size = float(step_size)
        self.noise_std = float(noise_std)
        self.goal = np.array(goal, dtype=np.float32)
        self.goal_radius = float(goal_radius)

        self.rng = np.random.default_rng(seed)
        self.pos = np.zeros((self.n_envs, 2), dtype=np.float32)
        self.t = np.zeros(self.n_envs, dtype=np.int32)

    @property
    def obs_dim(self) -> int:
        return 2

    @property
    def n_actions(self) -> int:
        # 0 stay, 1 up, 2 down, 3 left, 4 right
        return 5

    def reset(self) -> np.ndarray:
        self.t[:] = 0
        self.pos[:] = self.rng.uniform(low=-1.0, high=1.0, size=(self.n_envs, 2)).astype(np.float32)
        return self.pos.copy()

    def reset_done(self, done_mask: np.ndarray) -> None:
        idx = np.where(done_mask)[0]
        if len(idx) == 0:
            return
        self.t[idx] = 0
        self.pos[idx] = self.rng.uniform(low=-1.0, high=1.0, size=(len(idx), 2)).astype(np.float32)

    def step(self, actions: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
        actions = np.asarray(actions, dtype=np.int64)
        assert actions.shape == (self.n_envs,)

        move = np.zeros((self.n_envs, 2), dtype=np.float32)
        move[actions == 1, 1] = 1.0
        move[actions == 2, 1] = -1.0
        move[actions == 3, 0] = -1.0
        move[actions == 4, 0] = 1.0

        noise = self.rng.normal(loc=0.0, scale=self.noise_std, size=(self.n_envs, 2)).astype(np.float32)
        self.pos = np.clip(self.pos + self.step_size * move + noise, -1.0, 1.0)
        self.t += 1

        dist = np.linalg.norm(self.pos - self.goal[None, :], axis=1)
        success = dist < self.goal_radius
        done = success | (self.t >= self.max_steps)

        # True environment reward (for monitoring): small time penalty, big success bonus
        reward = -0.01 * np.ones(self.n_envs, dtype=np.float32)
        reward = reward + success.astype(np.float32) * 1.0

        info = {
            "dist": dist.astype(np.float32),
            "success": success.astype(np.bool_),
        }
        return self.pos.copy(), reward, done.astype(np.bool_), info


def expert_policy(obs: np.ndarray) -> np.ndarray:
    # Greedy expert: move along the axis with the larger |coordinate|, toward 0 (stay at the exact origin).
    x = obs[:, 0]
    y = obs[:, 1]
    ax = np.abs(x)
    ay = np.abs(y)

    actions = np.zeros(len(obs), dtype=np.int64)
    choose_x = ax >= ay

    actions[choose_x & (x > 0)] = 3  # left
    actions[choose_x & (x < 0)] = 4  # right

    actions[(~choose_x) & (y > 0)] = 2  # down
    actions[(~choose_x) & (y < 0)] = 1  # up

    return actions

# Quick look at one expert trajectory

env_one = VectorPointNav2D(
    n_envs=1,
    max_steps=MAX_STEPS,
    step_size=STEP_SIZE,
    noise_std=NOISE_STD,
    goal_radius=GOAL_RADIUS,
    seed=SEED,
)
obs = env_one.reset()

traj = [obs[0].copy()]
actions = []
rewards = []

done = np.array([False])
while not done[0]:
    a = expert_policy(obs)[0]
    obs, r, done, info = env_one.step(np.array([a]))
    traj.append(obs[0].copy())
    actions.append(int(a))
    rewards.append(float(r[0]))

traj = np.stack(traj)

fig = go.Figure()
fig.add_trace(go.Scatter(x=traj[:, 0], y=traj[:, 1], mode="lines+markers", name="expert"))
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", marker=dict(size=12, symbol="x"), name="goal"))
fig.update_layout(title="One expert trajectory (2D point → origin)", xaxis_title="x", yaxis_title="y")
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()

print("steps:", len(actions), "episode_return:", sum(rewards), "success:", info["success"][0])

steps: 4 episode_return: 0.9600000102072954 success: True
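The greedy rule is not guaranteed to reach the goal from every start. The sketch below re-implements the rule for a single point in plain NumPy (the helpers `greedy_move` and `run` are illustrative and not part of the notebook; they use float64 and no noise) and compares a start that succeeds with one that oscillates until the step limit:

```python
import numpy as np

STEP, RADIUS, MAX_T = 0.20, 0.12, 60  # same values as STEP_SIZE, GOAL_RADIUS, MAX_STEPS

def greedy_move(p: np.ndarray) -> np.ndarray:
    # Same rule as expert_policy, for one point: move along the larger |coordinate| toward 0.
    x, y = p
    if abs(x) >= abs(y):
        if x > 0:
            return np.array([-1.0, 0.0])  # left
        if x < 0:
            return np.array([1.0, 0.0])   # right
        return np.zeros(2)                # exactly at the origin: stay
    return np.array([0.0, -1.0]) if y > 0 else np.array([0.0, 1.0])  # down / up

def run(start: tuple[float, float]) -> tuple[bool, int]:
    # Returns (success, number of steps taken).
    p = np.array(start, dtype=np.float64)
    for t in range(1, MAX_T + 1):
        p = np.clip(p + STEP * greedy_move(p), -1.0, 1.0)
        if np.linalg.norm(p) < RADIUS:
            return True, t
    return False, MAX_T

print(run((0.5, 0.2)))  # (True, 3)
print(run((0.1, 0.1)))  # (False, 60)
```

From (0.1, 0.1) the point is about 0.141 from the origin, just outside the goal radius, and every greedy step of 0.20 flips the sign of one coordinate without reducing the distance. Demonstrations collected from such episodes contain long runs of alternating left/right (or up/down) actions.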

3) Expert demonstrations dataset

GAIL trains the discriminator on expert and policy ((s,a)) samples. We’ll generate a dataset of expert state-action pairs by running the expert in the environment.

def collect_expert_pairs(n_episodes: int, seed: int) -> tuple[np.ndarray, np.ndarray]:
    env = VectorPointNav2D(
        n_envs=1,
        max_steps=MAX_STEPS,
        step_size=STEP_SIZE,
        noise_std=NOISE_STD,
        goal_radius=GOAL_RADIUS,
        seed=seed,
    )

    obs_list: list[np.ndarray] = []
    act_list: list[int] = []

    for _ in range(int(n_episodes)):
        obs = env.reset()
        done = np.array([False])
        while not done[0]:
            a = int(expert_policy(obs)[0])
            obs_list.append(obs[0].copy())
            act_list.append(a)
            obs, r, done, info = env.step(np.array([a]))

    expert_obs = np.stack(obs_list).astype(np.float32)
    expert_acts = np.array(act_list, dtype=np.int64)
    return expert_obs, expert_acts


expert_obs, expert_acts = collect_expert_pairs(N_EXPERT_EPISODES, seed=SEED)
print("expert_obs", expert_obs.shape, "expert_acts", expert_acts.shape)

fig = px.histogram(x=expert_acts, nbins=5, title="Expert action histogram")
fig.update_layout(xaxis_title="action (0 stay, 1 up, 2 down, 3 left, 4 right)", yaxis_title="count")
fig.show()

expert_obs (2068, 2) expert_acts (2068,)

4) Low-level PyTorch: policy/value network and discriminator

We’ll use:

  • Policy/value: a small shared MLP with two heads

    • policy head outputs categorical logits over 5 actions

    • value head outputs (V_\theta(s))

  • Discriminator: an MLP over ((s, \text{one-hot}(a))) returning a logit (f_\phi(s,a))
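As a concrete picture of the discriminator input, the sketch below builds the concatenated (state, one-hot action) vector for two hand-picked pairs. The tensors are illustrative; the layout (2 state dims followed by 5 action dims) follows the description above:

```python
import torch
import torch.nn.functional as F

obs = torch.tensor([[0.5, -0.2], [0.0, 0.9]])  # two states (x, y)
actions = torch.tensor([3, 1])                 # 3 = left, 1 = up

# Concatenate each state with the one-hot encoding of its action: 2 + 5 = 7 features.
x = torch.cat([obs, F.one_hot(actions, num_classes=5).float()], dim=-1)

print(x.shape)  # torch.Size([2, 7])
print(x)
```

One-hot encoding keeps the discrete actions unordered from the network's point of view; feeding the raw action index as a single float would impose an artificial ordering on stay/up/down/left/right.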

def one_hot(actions: torch.Tensor, n_actions: int) -> torch.Tensor:
    return F.one_hot(actions.long(), num_classes=n_actions).float()


class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, ...] = (64, 64)):
        super().__init__()
        layers: list[nn.Module] = []
        in_dim = obs_dim
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.Tanh()]
            in_dim = h

        self.shared = nn.Sequential(*layers)
        self.pi = nn.Linear(in_dim, n_actions)
        self.v = nn.Linear(in_dim, 1)

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x = self.shared(obs)
        logits = self.pi(x)
        value = self.v(x).squeeze(-1)
        return logits, value

    def get_action_and_value(
        self,
        obs: torch.Tensor,
        action: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        logits, value = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = dist.sample()
        logp = dist.log_prob(action)
        entropy = dist.entropy()
        return action, logp, entropy, value


class Discriminator(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, ...] = (128, 128)):
        super().__init__()
        in_dim = obs_dim + n_actions
        layers: list[nn.Module] = []
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.Tanh()]
            in_dim = h
        layers += [nn.Linear(in_dim, 1)]

        self.net = nn.Sequential(*layers)
        self.n_actions = n_actions

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        x = torch.cat([obs, one_hot(actions, self.n_actions)], dim=-1)
        return self.net(x).squeeze(-1)  # logits


policy = ActorCritic(obs_dim=2, n_actions=5).to(DEVICE)
disc = Discriminator(obs_dim=2, n_actions=5).to(DEVICE)

pi_opt = torch.optim.Adam(policy.parameters(), lr=PI_LR)
d_opt = torch.optim.Adam(disc.parameters(), lr=D_LR)

5) Rollouts, GAE, discriminator update, PPO update

We’ll collect policy rollouts from a vectorized environment, then alternate:

  1. Discriminator update(s) using expert pairs and current policy pairs

  2. Policy PPO update(s) using discriminator-derived rewards

We’ll use GAE((\gamma,\lambda)) for advantages.
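GAE computes advantages with a backward recursion: (\delta_t = r_t + \gamma V(s_{t+1})(1-d_t) - V(s_t)) and (\hat A_t = \delta_t + \gamma\lambda(1-d_t)\hat A_{t+1}), where (d_t = 1) marks the last step of an episode. A tiny worked example for a single environment, with made-up rewards and values:

```python
import numpy as np

gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
dones = np.array([0.0, 1.0, 0.0])  # the episode ends at step 1; step 2 starts a new one
last_value = 0.5                    # bootstrap value for the state after the final step

adv = np.zeros(3)
next_value, next_adv = last_value, 0.0
for t in reversed(range(3)):
    mask = 1.0 - dones[t]           # 0 at an episode boundary: no bootstrapping across it
    delta = rewards[t] + gamma * next_value * mask - values[t]
    next_adv = delta + gamma * lam * mask * next_adv
    adv[t] = next_adv
    next_value = values[t]

returns = adv + values              # regression targets for the value head
print(adv)      # [1.46525 0.5     0.995  ]
print(returns)  # [1.96525 1.      1.495  ]
```

At step 1 the mask is 0, so the advantage is just the one-step error 1.0 - 0.5; at step 0 the recursion adds the discounted advantage from step 1.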

def rollout(env: VectorPointNav2D, policy: ActorCritic, n_steps: int) -> dict:
    obs = env.reset()
    n_envs = env.n_envs

    obs_buf = []
    act_buf = []
    logp_buf = []
    val_buf = []
    done_buf = []
    true_r_buf = []

    ep_returns = np.zeros(n_envs, dtype=np.float32)
    completed_returns: list[float] = []
    completed_success: list[bool] = []

    for _ in range(int(n_steps)):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
        with torch.no_grad():
            action, logp, entropy, value = policy.get_action_and_value(obs_t)

        act_np = action.cpu().numpy()
        next_obs, true_r, done, info = env.step(act_np)

        obs_buf.append(obs.copy())
        act_buf.append(act_np.copy())
        logp_buf.append(logp.cpu().numpy())
        val_buf.append(value.cpu().numpy())
        done_buf.append(done.copy())
        true_r_buf.append(true_r.copy())

        ep_returns += true_r
        for i in range(n_envs):
            if done[i]:
                completed_returns.append(float(ep_returns[i]))
                completed_success.append(bool(info["success"][i]))
                ep_returns[i] = 0.0

        env.reset_done(done)
        obs = env.pos.copy()

    with torch.no_grad():
        _, last_values = policy.forward(torch.tensor(obs, dtype=torch.float32, device=DEVICE))

    return {
        "obs": np.asarray(obs_buf, dtype=np.float32),
        "actions": np.asarray(act_buf, dtype=np.int64),
        "logp": np.asarray(logp_buf, dtype=np.float32),
        "values": np.asarray(val_buf, dtype=np.float32),
        "dones": np.asarray(done_buf, dtype=np.bool_),
        "true_rewards": np.asarray(true_r_buf, dtype=np.float32),
        "last_values": last_values.cpu().numpy().astype(np.float32),
        "completed_returns": completed_returns,
        "completed_success": completed_success,
    }


def compute_gae(
    rewards: np.ndarray,
    values: np.ndarray,
    dones: np.ndarray,
    last_values: np.ndarray,
    gamma: float,
    lam: float,
) -> tuple[np.ndarray, np.ndarray]:
    # GAE-Lambda. Shapes: rewards/values/dones are (T, N). last_values is (N,).
    T, N = rewards.shape
    adv = np.zeros((T, N), dtype=np.float32)

    last_adv = np.zeros(N, dtype=np.float32)
    next_values = last_values.astype(np.float32)

    for t in reversed(range(T)):
        mask = 1.0 - dones[t].astype(np.float32)
        delta = rewards[t] + gamma * next_values * mask - values[t]
        last_adv = delta + gamma * lam * mask * last_adv
        adv[t] = last_adv
        next_values = values[t]

    returns = adv + values
    return adv, returns


def train_discriminator(
    disc: Discriminator,
    opt: torch.optim.Optimizer,
    expert_obs: np.ndarray,
    expert_acts: np.ndarray,
    gen_obs: np.ndarray,
    gen_acts: np.ndarray,
    epochs: int,
    batch_size: int,
) -> float:
    # BCE discriminator update. expert label=1, generator label=0.
    disc.train()

    n_gen = len(gen_obs)
    n_exp = len(expert_obs)
    n = min(n_gen, n_exp)

    idx_g = np.random.randint(0, n_gen, size=n)
    idx_e = np.random.randint(0, n_exp, size=n)

    g_obs = torch.tensor(gen_obs[idx_g], dtype=torch.float32, device=DEVICE)
    g_act = torch.tensor(gen_acts[idx_g], dtype=torch.int64, device=DEVICE)

    e_obs = torch.tensor(expert_obs[idx_e], dtype=torch.float32, device=DEVICE)
    e_act = torch.tensor(expert_acts[idx_e], dtype=torch.int64, device=DEVICE)

    losses: list[float] = []
    for _ in range(int(epochs)):
        perm = torch.randperm(n, device=DEVICE)
        for start in range(0, n, int(batch_size)):
            mb = perm[start : start + int(batch_size)]

            logits_g = disc(g_obs[mb], g_act[mb])
            logits_e = disc(e_obs[mb], e_act[mb])

            loss_g = F.binary_cross_entropy_with_logits(logits_g, torch.zeros_like(logits_g))
            loss_e = F.binary_cross_entropy_with_logits(logits_e, torch.ones_like(logits_e))
            loss = loss_g + loss_e

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            losses.append(float(loss.detach().cpu()))

    return float(np.mean(losses))


def ppo_update(
    policy: ActorCritic,
    opt: torch.optim.Optimizer,
    obs: np.ndarray,
    actions: np.ndarray,
    old_logp: np.ndarray,
    advantages: np.ndarray,
    returns: np.ndarray,
    clip_eps: float,
    vf_coef: float,
    ent_coef: float,
    epochs: int,
    batch_size: int,
) -> dict:
    policy.train()

    n = len(obs)
    obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
    act_t = torch.tensor(actions, dtype=torch.int64, device=DEVICE)
    old_logp_t = torch.tensor(old_logp, dtype=torch.float32, device=DEVICE)
    adv_t = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
    ret_t = torch.tensor(returns, dtype=torch.float32, device=DEVICE)

    adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)

    total_losses: list[float] = []
    policy_losses: list[float] = []
    value_losses: list[float] = []
    entropies: list[float] = []
    approx_kls: list[float] = []

    for _ in range(int(epochs)):
        perm = torch.randperm(n, device=DEVICE)
        for start in range(0, n, int(batch_size)):
            mb = perm[start : start + int(batch_size)]

            _, logp, entropy, value = policy.get_action_and_value(obs_t[mb], act_t[mb])
            ratio = torch.exp(logp - old_logp_t[mb])

            pg1 = ratio * adv_t[mb]
            pg2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_t[mb]
            policy_loss = -torch.mean(torch.minimum(pg1, pg2))

            value_loss = F.mse_loss(value, ret_t[mb])
            entropy_bonus = torch.mean(entropy)

            loss = policy_loss + vf_coef * value_loss - ent_coef * entropy_bonus

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            opt.step()

            approx_kl = torch.mean(old_logp_t[mb] - logp).detach().cpu().item()

            total_losses.append(float(loss.detach().cpu()))
            policy_losses.append(float(policy_loss.detach().cpu()))
            value_losses.append(float(value_loss.detach().cpu()))
            entropies.append(float(entropy_bonus.detach().cpu()))
            approx_kls.append(float(approx_kl))

    return {
        "loss": float(np.mean(total_losses)),
        "policy_loss": float(np.mean(policy_losses)),
        "value_loss": float(np.mean(value_losses)),
        "entropy": float(np.mean(entropies)),
        "approx_kl": float(np.mean(approx_kls)),
    }

6) Train GAIL (alternate D and PPO updates)

We’ll track:

  • discriminator loss

  • PPO diagnostics (policy/value/entropy/KL)

  • episodic returns from the true environment reward (monitoring only)

  • evaluation return + success rate

def gail_reward_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # r = -log(1 - sigmoid(logits)) = softplus(logits)
    return F.softplus(logits)


def evaluate_policy(policy: ActorCritic, seed: int, n_episodes: int) -> dict:
    env = VectorPointNav2D(
        n_envs=1,
        max_steps=MAX_STEPS,
        step_size=STEP_SIZE,
        noise_std=NOISE_STD,
        goal_radius=GOAL_RADIUS,
        seed=seed,
    )

    returns: list[float] = []
    successes: list[bool] = []
    steps: list[int] = []

    for _ in range(int(n_episodes)):
        obs = env.reset()
        done = np.array([False])

        ep_return = 0.0
        ep_steps = 0
        ep_success = False

        while not done[0]:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
            with torch.no_grad():
                logits, _ = policy.forward(obs_t)
                action = torch.argmax(logits, dim=-1)

            obs, r, done, info = env.step(action.cpu().numpy())
            ep_return += float(r[0])
            ep_steps += 1
            ep_success = bool(info["success"][0])

        returns.append(ep_return)
        successes.append(ep_success)
        steps.append(ep_steps)

    return {
        "return_mean": float(np.mean(returns)),
        "return_std": float(np.std(returns)),
        "success_rate": float(np.mean(successes)),
        "steps_mean": float(np.mean(steps)),
    }


def evaluate_expert(seed: int, n_episodes: int) -> dict:
    env = VectorPointNav2D(
        n_envs=1,
        max_steps=MAX_STEPS,
        step_size=STEP_SIZE,
        noise_std=NOISE_STD,
        goal_radius=GOAL_RADIUS,
        seed=seed,
    )

    returns: list[float] = []
    successes: list[bool] = []
    steps: list[int] = []

    for _ in range(int(n_episodes)):
        obs = env.reset()
        done = np.array([False])

        ep_return = 0.0
        ep_steps = 0
        ep_success = False

        while not done[0]:
            a = int(expert_policy(obs)[0])
            obs, r, done, info = env.step(np.array([a]))
            ep_return += float(r[0])
            ep_steps += 1
            ep_success = bool(info["success"][0])

        returns.append(ep_return)
        successes.append(ep_success)
        steps.append(ep_steps)

    return {
        "return_mean": float(np.mean(returns)),
        "return_std": float(np.std(returns)),
        "success_rate": float(np.mean(successes)),
        "steps_mean": float(np.mean(steps)),
    }


env = VectorPointNav2D(
    n_envs=N_ENVS,
    max_steps=MAX_STEPS,
    step_size=STEP_SIZE,
    noise_std=NOISE_STD,
    goal_radius=GOAL_RADIUS,
    seed=SEED,
)

# Baseline: how good is the expert on this environment reward?
expert_eval = evaluate_expert(seed=SEED + 123, n_episodes=EVAL_EPISODES)
expert_eval

{'return_mean': 0.8834500104933977,
 'return_std': 0.322598197569875,
 'success_rate': 0.955,
 'steps_mean': 7.155}

disc_loss_hist: list[float] = []
ppo_loss_hist: list[float] = []
ppo_policy_loss_hist: list[float] = []
ppo_value_loss_hist: list[float] = []
ppo_entropy_hist: list[float] = []
ppo_kl_hist: list[float] = []

train_ep_returns: list[float] = []
train_ep_success: list[bool] = []
train_ep_iter: list[int] = []

eval_iters: list[int] = []
eval_return_mean: list[float] = []
eval_success_rate: list[float] = []

start = time.time()

for it in range(int(ITERATIONS)):
    data = rollout(env, policy, n_steps=STEPS_PER_ITER)

    # Flatten rollout buffers for discriminator/policy updates
    obs = data["obs"].reshape(-1, env.obs_dim)
    acts = data["actions"].reshape(-1)
    old_logp = data["logp"].reshape(-1)

    # 1) Discriminator update
    dloss = train_discriminator(
        disc=disc,
        opt=d_opt,
        expert_obs=expert_obs,
        expert_acts=expert_acts,
        gen_obs=obs,
        gen_acts=acts,
        epochs=D_EPOCHS,
        batch_size=D_BATCH_SIZE,
    )

    # 2) Compute discriminator reward for the policy rollout
    disc.eval()
    with torch.no_grad():
        logits = disc(
            torch.tensor(obs, dtype=torch.float32, device=DEVICE),
            torch.tensor(acts, dtype=torch.int64, device=DEVICE),
        )
        gail_rewards = gail_reward_from_logits(logits).cpu().numpy().reshape(data["values"].shape)  # (T, N)

    # 3) PPO update using GAE on the discriminator reward
    adv, rets = compute_gae(
        rewards=gail_rewards,
        values=data["values"],
        dones=data["dones"],
        last_values=data["last_values"],
        gamma=GAMMA,
        lam=LAMBDA_GAE,
    )

    ppo_stats = ppo_update(
        policy=policy,
        opt=pi_opt,
        obs=obs,
        actions=acts,
        old_logp=old_logp,
        advantages=adv.reshape(-1),
        returns=rets.reshape(-1),
        clip_eps=CLIP_EPS,
        vf_coef=VF_COEF,
        ent_coef=ENT_COEF,
        epochs=PPO_EPOCHS,
        batch_size=PPO_BATCH_SIZE,
    )

    disc_loss_hist.append(dloss)
    ppo_loss_hist.append(ppo_stats["loss"])
    ppo_policy_loss_hist.append(ppo_stats["policy_loss"])
    ppo_value_loss_hist.append(ppo_stats["value_loss"])
    ppo_entropy_hist.append(ppo_stats["entropy"])
    ppo_kl_hist.append(ppo_stats["approx_kl"])

    # Record true episodic returns completed during this iteration
    for r, s in zip(data["completed_returns"], data["completed_success"]):
        train_ep_returns.append(float(r))
        train_ep_success.append(bool(s))
        train_ep_iter.append(it)

    # Evaluate periodically
    if (it + 1) % int(EVAL_EVERY) == 0:
        stats = evaluate_policy(policy, seed=SEED + 999 + it, n_episodes=EVAL_EPISODES)
        eval_iters.append(it + 1)
        eval_return_mean.append(stats["return_mean"])
        eval_success_rate.append(stats["success_rate"])

elapsed = time.time() - start
elapsed

3.971886157989502

7) Plotly diagnostics

The plots below cover:

  • discriminator loss

  • policy learning (evaluation return + success rate)

  • episodic rewards (environment return per episode)

# Discriminator loss over iterations

df_disc = pd.DataFrame({
    "iteration": np.arange(1, len(disc_loss_hist) + 1),
    "disc_loss": disc_loss_hist,
})

fig = px.line(df_disc, x="iteration", y="disc_loss", title="Discriminator loss")
fig.update_layout(xaxis_title="iteration", yaxis_title="BCE loss")
fig.show()

# Policy learning: evaluation return + success rate

df_eval = pd.DataFrame({
    "iteration": eval_iters,
    "eval_return_mean": eval_return_mean,
    "eval_success_rate": eval_success_rate,
})

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_eval["iteration"], y=df_eval["eval_return_mean"], mode="lines+markers", name="eval return"))
fig.add_trace(go.Scatter(x=df_eval["iteration"], y=df_eval["eval_success_rate"], mode="lines+markers", name="success rate", yaxis="y2"))

fig.update_layout(
    title="Policy learning (evaluation)",
    xaxis=dict(title="iteration"),
    yaxis=dict(title="mean episodic return"),
    yaxis2=dict(title="success rate", overlaying="y", side="right", range=[0, 1]),
)
fig.show()

print("Expert baseline:", expert_eval)

Expert baseline: {'return_mean': 0.8834500104933977, 'return_std': 0.322598197569875, 'success_rate': 0.955, 'steps_mean': 7.155}

# Episodic rewards collected during training

if len(train_ep_returns) == 0:
    print("No completed episodes recorded (increase STEPS_PER_ITER or MAX_STEPS).")
else:
    df_ep = pd.DataFrame({
        "episode": np.arange(1, len(train_ep_returns) + 1),
        "return": train_ep_returns,
        "success": train_ep_success,
        "iteration": train_ep_iter,
    })

    fig = px.scatter(
        df_ep,
        x="episode",
        y="return",
        color="success",
        title="Episodic returns during training (true env reward)",
        opacity=0.6,
    )
    fig.update_layout(xaxis_title="episode", yaxis_title="episodic return")
    fig.show()

    window = 25
    if len(df_ep) >= window:
        ma = df_ep["return"].rolling(window=window).mean()
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["return"], mode="markers", name="return", opacity=0.35))
        fig.add_trace(go.Scatter(x=df_ep["episode"], y=ma, mode="lines", name=f"moving avg (window={window})"))
        fig.update_layout(title="Episodic return with moving average", xaxis_title="episode", yaxis_title="return")
        fig.show()

8) Visual sanity check: expert vs learned trajectories

A qualitative check: roll out the expert and the learned policy from the same start state.

def rollout_single(policy: ActorCritic, seed: int, use_expert: bool) -> dict:
    env = VectorPointNav2D(
        n_envs=1,
        max_steps=MAX_STEPS,
        step_size=STEP_SIZE,
        noise_std=NOISE_STD,
        goal_radius=GOAL_RADIUS,
        seed=seed,
    )
    obs = env.reset()
    traj = [obs[0].copy()]
    rewards = []

    done = np.array([False])
    while not done[0]:
        if use_expert:
            a = int(expert_policy(obs)[0])
        else:
            obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
            with torch.no_grad():
                logits, _ = policy.forward(obs_t)
                a = int(torch.argmax(logits, dim=-1).cpu().item())

        obs, r, done, info = env.step(np.array([a]))
        traj.append(obs[0].copy())
        rewards.append(float(r[0]))

    return {
        "traj": np.stack(traj),
        "return": float(sum(rewards)),
        "success": bool(info["success"][0]),
    }


seed = SEED + 2025
expert_roll = rollout_single(policy, seed=seed, use_expert=True)
learned_roll = rollout_single(policy, seed=seed, use_expert=False)

fig = go.Figure()
fig.add_trace(go.Scatter(x=expert_roll["traj"][:, 0], y=expert_roll["traj"][:, 1], mode="lines+markers", name="expert"))
fig.add_trace(go.Scatter(x=learned_roll["traj"][:, 0], y=learned_roll["traj"][:, 1], mode="lines+markers", name="learned"))
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", marker=dict(size=12, symbol="x"), name="goal"))

fig.update_layout(
    title="Expert vs learned trajectory (same start)",
    xaxis_title="x",
    yaxis_title="y",
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()

print("expert", {k: expert_roll[k] for k in ["return", "success"]})
print("learned", {k: learned_roll[k] for k in ["return", "success"]})

expert {'return': 0.9600000102072954, 'success': True}
learned {'return': -0.5999999865889549, 'success': False}
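In the run above the learned rollout timed out: a return of about -0.6 is 60 steps at -0.01 each. A short FAST_RUN is one plausible reason. Another possible factor, not verified here, is the sign of the reward: softplus(f) is strictly positive, so in an environment where success ends the episode, a policy that keeps wandering can collect more surrogate reward than one that finishes quickly. The sketch below illustrates that effect with an uninformative discriminator (f = 0) and lists two other reward forms that can be built from the same logit. The names `discounted_sum`, `r_logd`, and `r_logit` are illustrative and are not used elsewhere in this notebook:

```python
import torch
import torch.nn.functional as F

GAMMA = 0.99

def discounted_sum(per_step_reward: float, n_steps: int) -> float:
    # Discounted return of a constant per-step reward over n_steps.
    return sum(per_step_reward * GAMMA**t for t in range(n_steps))

# With f = 0 everywhere, the surrogate reward is softplus(0) = log 2 at every step.
r = float(F.softplus(torch.tensor(0.0)))
ret_short = discounted_sum(r, 4)    # an episode that reaches the goal in 4 steps
ret_long = discounted_sum(r, 60)    # an episode that runs to the step limit
print(ret_short, ret_long)

# Other reward forms from the same logit f (illustrative alternatives).
f = torch.tensor([-2.0, 0.0, 2.0])
r_softplus = F.softplus(f)          # -log(1 - D): always positive
r_logd = -F.softplus(-f)            # log D: always negative
r_logit = r_softplus + r_logd       # log D - log(1 - D), which equals f: either sign
print(r_softplus, r_logd, r_logit)
```

An always-negative reward has the opposite bias (it favors ending episodes early), so which form suits a task depends on how its episodes terminate.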

9) Stable-Baselines GAIL: availability + hyperparameters

Does Stable-Baselines implement GAIL?

Yes: Stable-Baselines (the TensorFlow-based library, not SB3) ships a GAIL class. Upstream docs/source show:

  • stable_baselines.GAIL exists and is TRPO-based (inherits from TRPO)

  • it expects an ExpertDataset

  • it requires OpenMPI support (for the MPI-based TRPO implementation)

Example from upstream docs (Pendulum):

import gym

from stable_baselines import GAIL, SAC
from stable_baselines.gail import ExpertDataset, generate_expert_traj

# Generate expert trajectories (train expert)
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10)

# Load the expert dataset
dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)

model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000)
model.save("gail_pendulum")

Hyperparameters (Stable-Baselines GAIL)

From the Stable-Baselines GAIL docstring/source, the key knobs are:

TRPO / policy optimization (inherited)

  • gamma: discount factor

  • timesteps_per_batch: rollout horizon per TRPO batch

  • max_kl: KL constraint threshold (trust-region size)

  • cg_iters: conjugate-gradient iterations (for TRPO step)

  • lam: GAE((\lambda))

  • entcoeff: entropy regularization coefficient

  • cg_damping: damping for conjugate gradient / Fisher-vector products

  • vf_stepsize: value function optimizer step size

  • vf_iters: value function training iterations per update

  • hidden_size: MLP hidden sizes for the policy/value network

GAIL-specific (how often/fast to train each player)

  • g_step: number of generator/policy steps per epoch

  • d_step: number of discriminator steps per epoch

  • d_stepsize: discriminator/reward-giver learning rate

  • hidden_size_adversary: discriminator hidden size

  • adversary_entcoeff: entropy term used in the adversary loss (stabilization)

Notes:

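  • If d_step is too large (or d_stepsize too high), the discriminator can become too strong: it separates expert and policy pairs almost perfectly, so the reward saturates and gives the policy an uninformative, unstable learning signal.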

  • If g_step is too large with a weak discriminator, the policy can overfit to a stale reward signal.
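The interaction between g_step and d_step is easiest to see as a schedule. The sketch below is a hypothetical illustration of alternating updates, not Stable-Baselines code: the function `alternate`, its callbacks, the update order within an epoch, and the values 3 and 1 are all assumptions chosen for the example.

```python
from typing import Callable

def alternate(
    n_epochs: int,
    g_step: int,
    d_step: int,
    policy_step: Callable[[], None],
    disc_step: Callable[[], None],
) -> None:
    # Assumed schedule: g_step policy updates, then d_step discriminator updates, per epoch.
    for _ in range(n_epochs):
        for _ in range(g_step):
            policy_step()
        for _ in range(d_step):
            disc_step()

order: list[str] = []
alternate(
    2,
    g_step=3,
    d_step=1,
    policy_step=lambda: order.append("G"),
    disc_step=lambda: order.append("D"),
)
print("".join(order))  # GGGDGGGD
```

Raising d_step lengthens each run of D updates between policy updates; raising g_step means the policy takes more steps against a reward that has not been refreshed.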

SB3 note

Stable-Baselines3 does not ship GAIL in core; in practice, people often use the separate imitation library (HumanCompatibleAI) with SB3 policies.


Exercises

  1. Replace PPO with a TRPO-style update (harder, but closer to the original paper).

  2. Add reward normalization (as Stable-Baselines’ adversary optionally does) and see how curves change.

  3. Make the environment continuous-action and switch the policy to a Gaussian distribution.


References

  • Ho & Ermon (2016), Generative Adversarial Imitation Learning: https://arxiv.org/abs/1606.03476

  • Stable-Baselines GAIL docs: https://stable-baselines.readthedocs.io/en/master/modules/gail.html

  • Stable-Baselines source (stable_baselines/gail/model.py): https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/gail/model.py